from loader import generate_dataset, Sets

from models import generate_model, SimpleNet

# Defaults reproduce the values this script originally hard-coded.
DATA_DIR = "../cifar-10-batches-py"
TRAIN_BATCH_SIZE = 512
TEST_BATCH_SIZE = 1024


def main(data_dir=DATA_DIR, train_batch=TRAIN_BATCH_SIZE, test_batch=TEST_BATCH_SIZE):
    """Load the data, build train/test datasets, and pass them to ``generate_model``.

    Args:
        data_dir: Location handed to ``Sets``. It is presumably a filesystem
            path, in which case a relative value such as the default resolves
            against the current working directory, not this file. TODO confirm
            in ``loader.Sets``.
        train_batch: ``batch`` argument for the training dataset.
        test_batch: ``batch`` argument for the test dataset.

    The return value of ``generate_model`` is discarded, as in the original script.
    """
    # ``Sets(...)()`` yields four values, unpacked in this order.
    x_train, x_test, y_train, y_test = Sets(data_dir)()

    sets = {
        "train": generate_dataset(x_train, y_train, batch=train_batch),
        "test": generate_dataset(x_test, y_test, batch=test_batch),
    }

    net = SimpleNet()
    generate_model(net, sets)


if __name__ == "__main__":
    main()